Conversation

deepcharm (Contributor) commented Nov 4, 2025

Compiled Autograd is an extension to torch.compile which enhances the autograd engine by capturing a larger backward computation graph at runtime. This allows a more comprehensive optimization of the backward pass during training.

Overall, 5-20% speedup is expected in backward-heavy workloads with stable graphs.

Disabled by default, the feature can be enabled from a user script by setting compiled_autograd_enabled=True when invoking the engine's compile method.
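
For illustration, a minimal sketch of how a training script might turn this on. The compiled_autograd_enabled keyword is the one described above; the model, config, and other engine.compile details are placeholders, not taken from this PR.

```python
import torch
import deepspeed

model = torch.nn.Linear(128, 128)
ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}

engine, _, _, _ = deepspeed.initialize(model=model,
                                        model_parameters=model.parameters(),
                                        config=ds_config)

# Opt in to Compiled Autograd when compiling the engine (it is off by default).
engine.compile(compiled_autograd_enabled=True)
```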

Note that bfloat16 with the eager backend requires PyTorch >= 2.5 (where partial fixes landed) or disabling compiled autograd for bfloat16 models, due to a known PyTorch bug in torch.compile (PyTorch #152162/#161153).


Signed-off-by: Max Kovalenko <[email protected]>
eternalNight (Contributor) commented:

@deepcharm Thanks for the patch!

Compiled autograd is not compatible with DeepCompile today, as it overrides the backward graph into which DeepCompile has inserted ZeRO ops. Having both enabled causes a torch._dynamo.exc.InternalTorchDynamoError (IndexError: list index out of range) in my local test.

Would you please warn the user and unset self._is_compiled_autograd_enabled when both deepcompile and compiled_autograd_enabled are set?
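
Roughly something like the following (a sketch only; I'm assuming the existing deepspeed.utils logger and whatever flag the engine uses to detect DeepCompile):

```python
from deepspeed.utils import logger

# Sketch of the requested guard inside the engine's compile path:
if deepcompile and compiled_autograd_enabled:
    logger.warning(
        "Compiled autograd is not supported together with DeepCompile; "
        "continuing with compiled_autograd_enabled=False."
    )
    compiled_autograd_enabled = False
self._is_compiled_autograd_enabled = compiled_autograd_enabled
```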

deepcharm (Contributor, Author) commented:

@eternalNight Thank you for the good catch! Updated the code per your request. Please let me know if that works.

eternalNight (Contributor) commented:

Thanks for the update!

I'm trying different combinations of compile options. When playing with this model (https://gist.github.com/eternalNight/3c2cf8c703f1e9e7742d3b7f9e1edae3), I got

[rank2]: torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function <method 'set_' of 'torch._C.TensorBase' objects>(*(FakeTensor(..., device='cuda:2', size=(32000, 4096)), FakeTensor(..., device='cuda:2', size=(32000, 4096), dtype=torch.bfloat16)), **{}): got RuntimeError('Could not set tensor of type c10::BFloat16 to a tensor of type float')

when using the eager backend with compiled_autograd_enabled=True and deepcompile disabled. With the inductor backend, I got the following error (even though I set TORCHINDUCTOR_AUTOGRAD_CACHE=0):

[rank3]:   File "/mnt/workspaces/venv/lib/python3.10/site-packages/torch/_tensor.py", line 648, in backward
[rank3]:     torch.autograd.backward(
[rank3]:   File "/mnt/workspaces/venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 353, in backward
[rank3]:     _engine_run_backward(
[rank3]:   File "/mnt/workspaces/venv/lib/python3.10/site-packages/torch/autograd/graph.py", line 824, in _engine_run_backward
[rank3]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank3]:   File "/mnt/workspaces/venv/lib/python3.10/site-packages/torch/_dynamo/compiled_autograd.py", line 1224, in set_node_origin
[rank3]:     raise RuntimeError(
[rank3]: RuntimeError: This compiled backward function was saved by AOTAutogradCache, which does not support
[rank3]:                     compiled autograd. Please turn off AOTAutogradCache using `TORCHINDUCTOR_AUTOGRAD_CACHE=0`.

Are those known issues of compiled_autograd?

deepcharm (Contributor, Author) commented:

Thanks for the detailed testing—super helpful! These errors match known PyTorch issues with Compiled Autograd + distributed/mixed precision:

  1. Eager/bfloat16: This is a known PyTorch bug in torch.compile (PyTorch #152162/#161153), where Dynamo tries to simulate a tensor operation but fails on a dtype mismatch when a tensor's .data is reassigned after dtype conversion (e.g., float → bfloat16). I've added a note in the PR description.

  2. Inductor cache: Stale AOTAutogradCache entries from prior runs (even with TORCHINDUCTOR_AUTOGRAD_CACHE=0).
    Clearing ~/.cache/torch/inductor should fix it.
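
Concretely, something like this before the run (the TORCHINDUCTOR_FX_GRAPH_CACHE variable and the cache paths are assumptions about a typical setup; TORCHINDUCTOR_AUTOGRAD_CACHE has to be in the environment before anything is compiled):

```python
import os
import shutil

# Disable the AOTAutograd and FX graph caches before torch compiles anything.
os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "0"
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "0"  # assumption: also rule out the FX graph cache

# Wipe on-disk cache directories left over from earlier runs.
shutil.rmtree(os.path.expanduser("~/.cache/torch/inductor"), ignore_errors=True)
shutil.rmtree(f"/tmp/torchinductor_{os.environ.get('USER', 'root')}", ignore_errors=True)

import torch
torch._dynamo.reset()  # drop anything already compiled in this process
```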

Let me know if clearing the cache resolves the second one for you. BTW what's your PyTorch version/setup?
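
As for point 1, here is a tiny illustration of the pattern that trips Dynamo. It is an illustrative sketch of the .data rebinding (the rebind_grad helper and the shapes are made up), not a guaranteed reproducer of the upstream bug:

```python
import torch

def rebind_grad(param, bf16_buf):
    # Rebinding .data to a tensor of a different dtype is allowed in eager mode,
    # but Dynamo traces the assignment as Tensor.set_, and fake-tensor propagation
    # of set_ rejects the float32 -> bfloat16 mismatch (see the error above).
    param.grad.data = bf16_buf
    return param.grad

param = torch.nn.Parameter(torch.randn(4, 4))       # float32 parameter
param.grad = torch.zeros_like(param)                 # float32 grad
bf16_buf = torch.zeros(4, 4, dtype=torch.bfloat16)   # bfloat16 storage to redirect into

compiled = torch.compile(rebind_grad, backend="eager")
# compiled(param, bf16_buf)  # may raise the set_ dtype error quoted earlier
```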

eternalNight (Contributor) commented Nov 14, 2025

> Thanks for the detailed testing—super helpful! These errors match known PyTorch issues with Compiled Autograd + distributed/mixed precision:
>
>   1. Eager/bfloat16: This is a known PyTorch bug in torch.compile (PyTorch #152162/#161153), where Dynamo tries to simulate a tensor operation but fails on a dtype mismatch when a tensor's .data is reassigned after dtype conversion (e.g., float → bfloat16). I've added a note in the PR description.

I'm using torch 2.7.1, and the model has torch autocast enabled by default, so does that mean compiled autograd should not be used with autocast now?

Using a bf16 model leads me to a different error:

[rank3]:   File "/mnt/workspaces/venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 570, in create_aot_dispatcher_function
[rank3]:     return _create_aot_dispatcher_function(
[rank3]:   File "/mnt/workspaces/venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 671, in _create_aot_dispatcher_function
[rank3]:     fw_metadata = run_functionalized_fw_and_collect_metadata(
[rank3]:   File "/mnt/workspaces/venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 197, in inner
[rank3]:     flat_f_outs = f(*flat_f_args)
[rank3]:   File "/mnt/workspaces/venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 899, in functional_call
[rank3]:     out = PropagateUnbackedSymInts(mod).run(
[rank3]:   File "/mnt/workspaces/venv/lib/python3.10/site-packages/torch/fx/interpreter.py", line 171, in run
[rank3]:     self.env[node] = self.run_node(node)
[rank3]:   File "/mnt/workspaces/venv/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7183, in run_node
[rank3]:     result = super().run_node(n)
[rank3]:   File "/mnt/workspaces/venv/lib/python3.10/site-packages/torch/fx/interpreter.py", line 240, in run_node
[rank3]:     return getattr(self, n.op)(n.target, args, kwargs)
[rank3]:   File "/mnt/workspaces/venv/lib/python3.10/site-packages/torch/fx/interpreter.py", line 344, in call_method
[rank3]:     return getattr(self_obj, target)(*args_tail, **kwargs)
[rank3]:   File "/mnt/workspaces/venv/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py", line 525, in __torch_dispatch__
[rank3]:     outs_unwrapped = func._op_dk(
[rank3]: torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
[rank3]: RuntimeError: Found a custom (non-ATen) operator whose output has alias annotations: aten::record_stream(Tensor(a!) self, Stream s) -> (). We only support functionalizing operators whose outputs do not have alias annotations (e.g. 'Tensor(a)' is a Tensor with an alias annotation whereas 'Tensor' is a Tensor without. The '(a)' is the alias annotation). The alias annotation specifies that the output Tensor shares storage with an input that has the same annotation. Please check if (1) the output needs to be an output (if not, don't return it), (2) if the output doesn't share storage with any inputs, then delete the alias annotation. (3) if the output indeed shares storage with an input, then add a .clone() before returning it to prevent storage sharing and then delete the alias annotation. Otherwise, please file an issue on GitHub.

This is what the graph looks like. The record_stream call is introduced by ZeRO-3.

[rank4]: GraphModule: class GraphModule(torch.nn.Module):
[rank4]:     def forward(self, L_param_grad: "bf16[32000, 4096][4096, 1]", L_self_ipg_buckets_torch_bfloat16_buffer: "bf16[200000000][1]"):
[rank4]:         l_param_grad = L_param_grad
[rank4]:         l_self_ipg_buckets_torch_bfloat16_buffer = L_self_ipg_buckets_torch_bfloat16_buffer
[rank4]:
[rank4]:          # File: /mnt/engines/deepspeed/deepspeed/accelerator/cuda_accelerator.py:118 in stream, code: return torch.cuda.stream(stream)
[rank4]:         stream = torch.cuda.streams.Stream(stream_id = 99, device_index = 4, device_type = 1)
[rank4]:         current_stream = torch.cuda.current_stream(None)
[rank4]:
[rank4]:          # File: /mnt/engines/deepspeed/deepspeed/runtime/zero/stage3.py:1310 in torch_dynamo_resume_in___add_grad_to_ipg_bucket_at_1305, code: with get_accelerator().stream(self.reduce_and_partition_stream):
[rank4]:         set_stream = torch.cuda.set_stream(stream);  stream = set_stream = None
[rank4]:
[rank4]:          # File: /mnt/engines/deepspeed/deepspeed/runtime/zero/stage3.py:1318 in torch_dynamo_resume_in___add_grad_to_ipg_bucket_at_1305, code: new_grad_tensor = bucket.buffer.narrow(0, bucket.elements, param.grad.numel()).view_as(param.grad)
[rank4]:         narrow: "bf16[131072000][1]" = l_self_ipg_buckets_torch_bfloat16_buffer.narrow(0, 0, 131072000);  l_self_ipg_buckets_torch_bfloat16_buffer = None
[rank4]:         new_grad_tensor: "bf16[32000, 4096][4096, 1]" = narrow.view_as(l_param_grad);  narrow = None
[rank4]:
[rank4]:          # File: /mnt/engines/deepspeed/deepspeed/runtime/zero/stage3.py:1319 in torch_dynamo_resume_in___add_grad_to_ipg_bucket_at_1305, code: new_grad_tensor.copy_(param.grad, non_blocking=True)
[rank4]:         copy_: "bf16[32000, 4096][4096, 1]" = new_grad_tensor.copy_(l_param_grad, non_blocking = True);  copy_ = None
[rank4]:
[rank4]:          # File: /mnt/engines/deepspeed/deepspeed/accelerator/cuda_accelerator.py:121 in current_stream, code: return torch.cuda.current_stream(device_index)
[rank4]:         current_stream_1 = torch.cuda.current_stream(None)
[rank4]:
[rank4]:          # File: /mnt/engines/deepspeed/deepspeed/runtime/zero/stage3.py:1321 in torch_dynamo_resume_in___add_grad_to_ipg_bucket_at_1305, code: param.grad.record_stream(get_accelerator().current_stream())
[rank4]:         record_stream = l_param_grad.record_stream(current_stream_1);  current_stream_1 = record_stream = None
[rank4]:
[rank4]:          # File: /mnt/engines/deepspeed/deepspeed/runtime/zero/stage3.py:1322 in torch_dynamo_resume_in___add_grad_to_ipg_bucket_at_1305, code: param.grad.data = new_grad_tensor
[rank4]:         set_: "bf16[32000, 4096][4096, 1]" = torch_Tensor_set_(l_param_grad, new_grad_tensor);  l_param_grad = new_grad_tensor = None
[rank4]:         _lower_version_count_by_1 = torch__dynamo_variables_builtin__lower_version_count_by_1(set_);  set_ = _lower_version_count_by_1 = None

>   2. Inductor cache: Stale AOTAutogradCache entries from prior runs (even with TORCHINDUCTOR_AUTOGRAD_CACHE=0).
>     Clearing ~/.cache/torch/inductor should fix it.

I've removed /tmp/torchinductor_root/ (which is where inductor caches generated graphs on my side), but the error persists.

> Let me know if clearing the cache resolves the second one for you. BTW what's your PyTorch version/setup?

I'm struggling to find a working example for DeepSpeed + compiled autograd. If you have a working model at hand, would you please include that as a unit test in this PR as well so that we can test its benefits? Thanks.
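
(Roughly, I'm imagining something along these lines, using the compiled_autograd_enabled keyword from this PR; the model and config are placeholders, not an actual working example:)

```python
import torch
import deepspeed

def test_backward_with_compiled_autograd():
    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1))
    ds_config = {
        "train_batch_size": 8,
        "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    }
    engine, _, _, _ = deepspeed.initialize(model=model,
                                           model_parameters=model.parameters(),
                                           config=ds_config)
    engine.compile(compiled_autograd_enabled=True)

    # One training step exercising the compiled backward pass.
    x = torch.randn(8, 64, device=engine.device)
    loss = engine(x).mean()
    engine.backward(loss)
    engine.step()
```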
